{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sample callback\n",
    "\n",
    "This notebook demonstrates the usage of the `callback` argument of `pm.sample`. A callback is a function that is called after every draw of a chain. It receives two arguments: the trace of that chain so far, which contains all samples drawn by that single chain, and the current draw.\n",
    "\n",
    "The sampling process can be interrupted by raising a `KeyboardInterrupt` from inside the callback.\n",
    "\n",
    "Use cases for this callback include:\n",
    "\n",
    " - Stopping sampling when a target number of effective samples is reached\n",
    " - Stopping sampling when there are too many divergences\n",
    " - Logging metrics to external tools (such as TensorBoard)\n",
    "\n",
    "We'll start by defining a simple model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pymc3 as pm\n",
    "import numpy as np\n",
    "\n",
    "# Simulated data from the line y = 2 * X plus standard normal noise\n",
    "X = np.array([1, 2, 3, 4, 5])\n",
    "y = X * 2 + np.random.randn(len(X))\n",
    "\n",
    "# Simple linear regression with broad Normal priors on intercept and slope\n",
    "with pm.Model() as model:\n",
    "    intercept = pm.Normal('intercept', 0, 10)\n",
    "    slope = pm.Normal('slope', 0, 10)\n",
    "\n",
    "    mean = intercept + slope * X\n",
    "    error = pm.HalfCauchy('error', 1)\n",
    "    obs = pm.Normal('obs', mean, error, observed=y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can then, for example, add a callback that stops sampling once 100 samples have been drawn, regardless of the number of `draws` passed to `pm.sample`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Auto-assigning NUTS sampler...\n",
      "Initializing NUTS using jitter+adapt_diag...\n",
      "Sequential sampling (1 chains in 1 job)\n",
      "NUTS: [error, slope, intercept]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "        <style>\n",
       "            /* Turns off some styling */\n",
       "            progress {\n",
       "                /* gets rid of default border in Firefox and Opera. */\n",
       "                border: none;\n",
       "                /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "                background-size: auto;\n",
       "            }\n",
       "            .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "                background: #F44336;\n",
       "            }\n",
       "        </style>\n",
       "      <progress value='0' class='progress-bar-interrupted' max='500', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      Interrupted\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "There were 12 divergences after tuning. Increase `target_accept` or reparameterize.\n",
      "The acceptance probability does not match the target. It is 0.5303940121554945, but should be close to 0.8. Try to increase the number of tuning steps.\n",
      "Only one chain was sampled, this makes it impossible to run some convergence checks\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "100\n"
     ]
    }
   ],
   "source": [
    "# Stop sampling once the trace of the chain holds 100 draws\n",
    "def my_callback(trace, draw):\n",
    "    if len(trace) >= 100:\n",
    "        raise KeyboardInterrupt()\n",
    "\n",
    "with model:\n",
    "    trace = pm.sample(tune=0, draws=500, callback=my_callback, chains=1)\n",
    "\n",
    "print(len(trace))"
   ]
  },
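  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Another use case from the introduction is stopping when a chain produces too many divergences. The cell below is a minimal sketch rather than part of the PyMC3 API: the class `StopOnDivergences` and its `max_divergences` parameter are hypothetical, and it assumes that `draw.stats` is either `None` or a list of per-step-method dictionaries in which NUTS reports a boolean `'diverging'` entry. Divergences are counted separately for each chain via `draw.chain`, and draws with `draw.tuning` set are ignored. Pass an instance to `pm.sample` with, e.g., `callback=StopOnDivergences(max_divergences=10)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical callback (not part of PyMC3): stop sampling once any chain\n",
    "# has produced more than `max_divergences` divergent draws after tuning.\n",
    "class StopOnDivergences:\n",
    "    def __init__(self, max_divergences=10):\n",
    "        self.max_divergences = max_divergences\n",
    "        self.divergences = {}  # chain id -> number of divergent draws so far\n",
    "\n",
    "    def __call__(self, trace, draw):\n",
    "        if draw.tuning:\n",
    "            # Divergences during tuning are not counted\n",
    "            return\n",
    "        # Assumption: draw.stats is None or a list of per-step-method dicts\n",
    "        stats = draw.stats or []\n",
    "        if any(s.get('diverging', False) for s in stats):\n",
    "            count = self.divergences.get(draw.chain, 0) + 1\n",
    "            self.divergences[draw.chain] = count\n",
    "            if count > self.max_divergences:\n",
    "                raise KeyboardInterrupt"
   ]
  },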
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note, however, that the trace passed to the callback only corresponds to a single chain. If we want to compute statistics across several chains at once, we need a bit of extra machinery. The next cell illustrates this: with two chains sampled in parallel, each chain's trace length is printed independently, so the printed values interleave."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Auto-assigning NUTS sampler...\n",
      "Initializing NUTS using jitter+adapt_diag...\n",
      "Multiprocess sampling (2 chains in 2 jobs)\n",
      "NUTS: [error, slope, intercept]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "        <style>\n",
       "            /* Turns off some styling */\n",
       "            progress {\n",
       "                /* gets rid of default border in Firefox and Opera. */\n",
       "                border: none;\n",
       "                /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "                background-size: auto;\n",
       "            }\n",
       "            .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "                background: #F44336;\n",
       "            }\n",
       "        </style>\n",
       "      <progress value='1000' class='' max='1000', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      100.00% [1000/1000 00:00<00:00 Sampling 2 chains, 518 divergences]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "100\n",
      "200\n",
      "100\n",
      "300\n",
      "400\n",
      "200\n",
      "500\n",
      "300\n",
      "400\n",
      "500\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The chain contains only diverging samples. The model is probably misspecified.\n",
      "The acceptance probability does not match the target. It is 0.0, but should be close to 0.8. Try to increase the number of tuning steps.\n",
      "There were 18 divergences after tuning. Increase `target_accept` or reparameterize.\n",
      "The acceptance probability does not match the target. It is 9.211751427765233e-155, but should be close to 0.8. Try to increase the number of tuning steps.\n",
      "The rhat statistic is larger than 1.4 for some parameters. The sampler did not converge.\n",
      "The estimated number of effective samples is smaller than 200 for some parameters.\n"
     ]
    }
   ],
   "source": [
    "# Print the length of the current chain's trace every 100 draws\n",
    "def my_callback(trace, draw):\n",
    "    if len(trace) % 100 == 0:\n",
    "        print(len(trace))\n",
    "\n",
    "with model:\n",
    "    trace = pm.sample(tune=0, draws=500, callback=my_callback, chains=2, cores=2)"
   ]
  },
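  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The interleaved output above comes from two chains sampled in parallel, and nothing in it shows which chain printed which number. For the logging use case from the introduction, a minimal sketch with Python's standard `logging` module can tag each message with `draw.chain`, the identifier of the chain that produced the current draw. The class `LogProgress`, its `every` parameter and the logger name `sample_callback` are hypothetical; sending metrics to an external tool such as TensorBoard would mean replacing the `logger.info` call with that tool's own writer. Use it with, e.g., `callback=LogProgress(every=100)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "\n",
    "# Hypothetical logger name; a dedicated handler makes INFO messages visible\n",
    "callback_logger = logging.getLogger('sample_callback')\n",
    "callback_logger.setLevel(logging.INFO)\n",
    "if not callback_logger.handlers:\n",
    "    callback_logger.addHandler(logging.StreamHandler())\n",
    "\n",
    "\n",
    "class LogProgress:\n",
    "    # Hypothetical callback: log each chain's trace length every `every` draws,\n",
    "    # tagged with draw.chain so output from parallel chains can be told apart\n",
    "    def __init__(self, every=100, logger=None):\n",
    "        self.every = every\n",
    "        self.logger = logger if logger is not None else callback_logger\n",
    "\n",
    "    def __call__(self, trace, draw):\n",
    "        n = len(trace)\n",
    "        if n % self.every == 0:\n",
    "            self.logger.info('chain %d: %d draws', draw.chain, n)"
   ]
  },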
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can use the `draw.chain` attribute to determine which chain the current draw and trace belong to. Combined with a convergence statistic such as r_hat, this lets us stop sampling once the chains have converged, regardless of the number of draws specified."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 128,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Auto-assigning NUTS sampler...\n",
      "Initializing NUTS using jitter+adapt_diag...\n",
      "Multiprocess sampling (2 chains in 2 jobs)\n",
      "NUTS: [error, slope, intercept]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "        <style>\n",
       "            /* Turns off some styling */\n",
       "            progress {\n",
       "                /* gets rid of default border in Firefox and Opera. */\n",
       "                border: none;\n",
       "                /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
       "                background-size: auto;\n",
       "            }\n",
       "            .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "                background: #F44336;\n",
       "            }\n",
       "        </style>\n",
       "      <progress value='3596' class='' max='202000', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      1.78% [3596/202000 00:02<02:10 Sampling 2 chains, 37 divergences]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The estimated number of effective samples is smaller than 200 for some parameters.\n"
     ]
    }
   ],
   "source": [
    "class MyCallback:\n",
    "    # Stop sampling once r_hat, computed across chains, is below max_rhat for all parameters\n",
    "    def __init__(self, every=1000, max_rhat=1.05):\n",
    "        self.every = every\n",
    "        self.max_rhat = max_rhat\n",
    "        self.traces = {}  # chain id -> most recent trace of that chain\n",
    "\n",
    "    def __call__(self, trace, draw):\n",
    "        # Ignore tuning draws\n",
    "        if draw.tuning:\n",
    "            return\n",
    "\n",
    "        self.traces[draw.chain] = trace\n",
    "        # Check convergence every `every` draws of the current chain\n",
    "        if len(trace) % self.every == 0:\n",
    "            multitrace = pm.backends.base.MultiTrace(list(self.traces.values()))\n",
    "            if pm.stats.rhat(multitrace).to_array().max() < self.max_rhat:\n",
    "                raise KeyboardInterrupt\n",
    "\n",
    "with model:\n",
    "    trace = pm.sample(tune=1000, draws=100000, callback=MyCallback(), chains=2, cores=2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
